"""Count ordered index pairs (i, j) with s[i] == s[j] for a line read from stdin.

Pairs are ordered and i == j is included, so a character occurring c times
contributes c * c to the total.
"""

from collections import Counter


def count_equal_pairs(s: str) -> int:
    """Return the number of ordered pairs (i, j) such that ``s[i] == s[j]``.

    Pairs with ``i == j`` are included, so the result equals the sum of the
    squared occurrence counts of every distinct character in *s*.

    >>> count_equal_pairs("great10")
    7
    >>> count_equal_pairs("aaaaaaaaaa")
    100
    >>> count_equal_pairs("")
    0
    """
    # Counter replaces the hand-rolled frequency dict; each character seen
    # c times yields c * c ordered (i, j) pairs.
    return sum(count * count for count in Counter(s).values())


def main() -> None:
    """Read one line from stdin and print its equal-pair count."""
    print(count_equal_pairs(input()))


# Guarding the entry point keeps the module importable (e.g. for tests)
# without blocking on stdin.
if __name__ == "__main__":
    main()